
ggml: adds CONV_2D op and direct GEMM Vulkan implementation #14316


Open · wants to merge 9 commits into base: master

Conversation

etasnadi
Contributor

This patch adds support for direct computation of 2D convolution in the Vulkan backend. It is implemented as a custom GEMM that loads the relevant kernel and input data into shared memory, so it does not need to materialize the convolution matrix in global memory with im2col, saving a lot of memory (similar to how the op is implemented in cuDNN). This approach can in theory produce faster kernels than im2col->matmul, because the full convolution matrix never has to travel between global memory and the registers, and the repeated elements of the (virtual) helper matrix can be pulled from L2.
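
To make the idea concrete, here is a minimal CPU sketch (not the shader from this PR; all names are illustrative) of the index mapping a direct/implicit GEMM uses: the kernel is viewed as a Cout x K matrix and the input as a virtual K x (N*OH*OW) matrix whose elements are fetched on the fly from the original tensor instead of being materialized by im2col.

```cpp
// Minimal CPU sketch of the implicit-GEMM (direct) convolution idea.
// Illustrative only; this is not the Vulkan shader from the PR.
#include <cstdio>
#include <vector>

// Fetch one element of the "virtual" im2col matrix without materializing it.
// k flattens (c, ky, kx); col flattens (n, oy, ox).
static float virtual_A(const std::vector<float> &input,
                       int C, int H, int W, int KH, int KW, int OH, int OW,
                       int k, int col) {
    const int kx = k % KW, ky = (k / KW) % KH, c = k / (KW * KH);
    const int ox = col % OW, oy = (col / OW) % OH, n = col / (OW * OH);
    const int ix = ox + kx;   // stride 1, no padding, no dilation, for brevity
    const int iy = oy + ky;
    return input[((n * C + c) * H + iy) * W + ix];
}

int main() {
    // Tiny problem: N=1, C=2, H=W=5, Cout=3, KH=KW=3 -> OH=OW=3.
    const int N = 1, C = 2, H = 5, W = 5, Cout = 3, KH = 3, KW = 3;
    const int OH = H - KH + 1, OW = W - KW + 1, K = C * KH * KW;

    std::vector<float> input(N * C * H * W), kernel(Cout * K), out(Cout * N * OH * OW, 0.0f);
    for (int i = 0; i < (int)input.size();  ++i) input[i]  = 0.01f * i;
    for (int i = 0; i < (int)kernel.size(); ++i) kernel[i] = 0.02f * (i % 7 - 3);

    // The direct convolution expressed as a GEMM over the virtual matrix:
    // out[oc][col] = sum_k kernel[oc][k] * virtual_A[k][col]
    for (int oc = 0; oc < Cout; ++oc) {
        for (int col = 0; col < N * OH * OW; ++col) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) {
                acc += kernel[oc * K + k] * virtual_A(input, C, H, W, KH, KW, OH, OW, k, col);
            }
            out[oc * N * OH * OW + col] = acc;
        }
    }
    printf("out[0..3] = %f %f %f %f\n", out[0], out[1], out[2], out[3]);
    return 0;
}
```

In the actual shader a blocktile of this virtual matrix is staged in shared memory per workgroup, so only the kernel and the raw input ever have to come from global memory.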

On an RTX 2060 the direct op is roughly 2x faster than im2col->matmul (4.10 TFLOPS vs. 2.15 TFLOPS according to test-backend-ops; the theoretical max is ~6 TFLOPS):

$ GGML_VK_DISABLE_COOPMAT=1 ./bin/test-backend-ops -o CONV_2D_INDIRECT_IMPL -b Vulkan0 perf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 2060 SUPER (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: none
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: NVIDIA GeForce RTX 2060 SUPER
  Device memory: 8192 MB (8192 MB free)

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       16 runs - 64065.69 us/run - 137.42 GFLOP/run -   2.15 TFLOPS
  Backend Vulkan0: OK

Backend 2/2: CPU
  Skipping
2/2 backends passed
OK

$ ./bin/test-backend-ops -o CONV_2D_INDIRECT_IMPL -b Vulkan0 perf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 2060 SUPER (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: KHR_coopmat
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: NVIDIA GeForce RTX 2060 SUPER
  Device memory: 8192 MB (8192 MB free)

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       46 runs - 21751.26 us/run - 137.42 GFLOP/run -   6.32 TFLOPS
  Backend Vulkan0: OK

Backend 2/2: CPU
  Skipping
2/2 backends passed
OK

$ GGML_VK_DISABLE_COOPMAT=1 ./bin/test-backend-ops -o CONV_2D_DIRECT_IMPL -b Vulkan0 perf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 2060 SUPER (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: none
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: NVIDIA GeForce RTX 2060 SUPER
  Device memory: 8192 MB (8192 MB free)

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 30 runs - 33534.17 us/run - 137.42 GFLOP/run -   4.10 TFLOPS
  Backend Vulkan0: OK

Backend 2/2: CPU
  Skipping
2/2 backends passed
OK

As a negative result, the indirect op is significantly faster on a GTX 1060 notebook (1.73 vs. 1.21 TFLOPS; the theoretical max is ~3 TFLOPS), possibly because the block tile sizes are too large for this older hardware.

The PR also adds support for comparing ops with different implementation graphs in test-backend-ops, so the actual op under development (potentially fused and optimized) can be compared and tested against a reference op that does not yet have a direct implementation on the CPU, which makes op development faster.
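
As a rough illustration of what such a comparison boils down to (a sketch in the spirit of test-backend-ops' error check; the names and the threshold here are mine), the op under development and the reference graph are evaluated on the same inputs and their outputs are compared with a normalized mean squared error:

```cpp
// Sketch of checking an op against a reference graph with a normalized mean
// squared error, in the spirit of test-backend-ops; names/threshold are mine.
#include <cstdio>
#include <vector>

static double nmse(const std::vector<float> &test, const std::vector<float> &ref) {
    double err = 0.0, norm = 0.0;
    for (int i = 0; i < (int)ref.size(); ++i) {
        const double d = (double)test[i] - (double)ref[i];
        err  += d * d;
        norm += (double)ref[i] * (double)ref[i];
    }
    return norm > 0.0 ? err / norm : err;
}

int main() {
    // e.g. output of CONV_2D vs. output of the im2col->matmul reference graph
    std::vector<float> direct    = {1.0f, 2.0f, 3.000001f};
    std::vector<float> reference = {1.0f, 2.0f, 3.0f};
    const double e = nmse(direct, reference);
    printf("NMSE = %g -> %s\n", e, e < 1e-7 ? "OK" : "FAIL");
    return 0;
}
```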

@github-actions github-actions bot added the testing (Everything test related), Vulkan (Issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels on Jun 21, 2025
Collaborator

@jeffbolznv jeffbolznv left a comment


Very cool!

@netrunnereve
Collaborator

netrunnereve commented Jun 21, 2025

As a negative result, the indirect op is significantly faster on a GTX 1060 notebook (1.73 vs. 1.21 TFLOPS; the theoretical max is ~3 TFLOPS), possibly because the block tile sizes are too large for this older hardware.

On my RX 470 the indirect op is faster as well. IMO it's worth testing with more input and kernel sizes like what we have for im2col, and the real test is to get this set up with stablediffusion.cpp (though that thing hasn't been updated for months) to see how it does with an actual model.

CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       14 runs - 72823.29 us/run - 137.42 GFLOP/run -   1.89 TFLOPS
CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 11 runs - 96444.18 us/run - 137.42 GFLOP/run -   1.42 TFLOPS

@etasnadi
Contributor Author

etasnadi commented Jun 21, 2025

As a negative result, the indirect op is significantly faster on a GTX 1060 notebook (1.73 vs. 1.21 TFLOPS; the theoretical max is ~3 TFLOPS), possibly because the block tile sizes are too large for this older hardware.

On my RX 470 the indirect op is faster as well. IMO it's worth testing with more input and kernel sizes like what we have for im2col, and the real test is to get this set up with stablediffusion.cpp (though that thing hasn't been updated for months) to see how it does with an actual model.

CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       14 runs - 72823.29 us/run - 137.42 GFLOP/run -   1.89 TFLOPS
CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 11 runs - 96444.18 us/run - 137.42 GFLOP/run -   1.42 TFLOPS

Sure, older models might introduce other bottlenecks that cause the shader to slow down, but the memory saving is still a considerable advantage. I'm thinking about reimplementing the shader in CUDA so I can profile it with Nsight to see what causes the issue (hopefully it still supports such ancient cards).

@rmatif
Collaborator

rmatif commented Jun 27, 2025

Out of curiosity I have tested it on a Mali GPU

./test-backend-ops -o CONV_2D_DIRECT_IMPL -b Vulkan0 perf                           
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Mali-G715 (Mali-G715) | uma: 1 | fp16: 1 | warp size: 16 | shared memory: 32768 | int dot: 1 | matrix cores: KHR_coopmat
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: Mali-G715
  Device memory: 11229 MB (11229 MB free)

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  1 runs - 9295123.00 us/run - 137.42 GFLOP/run -  14.78 GFLOPS
  Backend Vulkan0: OK

Backend 2/2: CPU
  Skipping
2/2 backends passed
OK

./test-backend-ops -o CONV_2D_INDIRECT_IMPL -b Vulkan0 perf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Mali-G715 (Mali-G715) | uma: 1 | fp16: 1 | warp size: 16 | shared memory: 32768 | int dot: 1 | matrix cores: KHR_coopmat
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: Mali-G715
  Device memory: 11229 MB (11229 MB free)

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                        1 runs - 4126959.00 us/run - 137.42 GFLOP/run -  33.30 GFLOPS
  Backend Vulkan0: OK

Backend 2/2: CPU
  Skipping
2/2 backends passed
OK

@etasnadi
Contributor Author

Out of curiosity I have tested it on a Mali GPU

./test-backend-ops -o CONV_2D_DIRECT_IMPL -b Vulkan0 perf                           
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Mali-G715 (Mali-G715) | uma: 1 | fp16: 1 | warp size: 16 | shared memory: 32768 | int dot: 1 | matrix cores: KHR_coopmat
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: Mali-G715
  Device memory: 11229 MB (11229 MB free)

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  1 runs - 9295123.00 us/run - 137.42 GFLOP/run -  14.78 GFLOPS
  Backend Vulkan0: OK

Backend 2/2: CPU
  Skipping
2/2 backends passed
OK

./test-backend-ops -o CONV_2D_INDIRECT_IMPL -b Vulkan0 perf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Mali-G715 (Mali-G715) | uma: 1 | fp16: 1 | warp size: 16 | shared memory: 32768 | int dot: 1 | matrix cores: KHR_coopmat
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: Mali-G715
  Device memory: 11229 MB (11229 MB free)

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                        1 runs - 4126959.00 us/run - 137.42 GFLOP/run -  33.30 GFLOPS
  Backend Vulkan0: OK

Backend 2/2: CPU
  Skipping
2/2 backends passed
OK

Hi, thanks for testing! Please disable coopmats for a fair comparison: my algorithm is currently fp32 scalar, while the indirect implementation is mixed precision and uses matrix cores. Anyway, I have already found cases where my algorithm is slower and I will update it soon.

@rmatif
Collaborator

rmatif commented Jun 27, 2025

Hi, thanks for testing! Please disable coopmats for a fair comparison: my algorithm is currently fp32 scalar, while the indirect implementation is mixed precision and uses matrix cores. Anyway, I have already found cases where my algorithm is slower and I will update it soon.

Here without coopmat:

GGML_VK_DISABLE_COOPMAT=1 ./test-backend-ops -o CONV_2D_DIRECT_IMPL -b Vulkan0 perf                                                
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Mali-G715 (Mali-G715) | uma: 1 | fp16: 1 | warp size: 16 | shared memory: 32768 | int dot: 1 | matrix cores: none
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: Mali-G715
  Device memory: 11229 MB (11229 MB free)

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  1 runs - 12637915.00 us/run - 137.42 GFLOP/run -  10.87 GFLOPS
  Backend Vulkan0: OK

Backend 2/2: CPU
  Skipping
2/2 backends passed
OK

GGML_VK_DISABLE_COOPMAT=1 ./test-backend-ops -o CONV_2D_INDIRECT_IMPL -b Vulkan0 perf                                                                        
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Mali-G715 (Mali-G715) | uma: 1 | fp16: 1 | warp size: 16 | shared memory: 32768 | int dot: 1 | matrix cores: none
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: Mali-G715
  Device memory: 11229 MB (11229 MB free)

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0): 0: 0x7b4d6d8378 
1: 0x7b4d6d8284 ggml_print_backtrace
2: 0x7b4d6e8f3c 
3: 0x7b4d74864c 
4: 0x7b4d762c30 __cxa_get_exception_ptr
5: 0x7b4d762c0c 
6: 0x7b4dad988c 
7: 0x7b4dabae48 
8: 0x7b4dab57d0 
9: 0x7b4d6ebb54 ggml_backend_graph_compute
10: 0x5a933e51a8 
11: 0x5a933d7ce8 
12: 0x7b48712218 __libc_init
libc++abi: terminating due to uncaught exception of type vk::DeviceLostError: vk::Device::getFenceStatus: ErrorDeviceLost
Aborted 

For some reason it crashes when running the indirect conv2d without coopmat

@0cc4m
Collaborator

0cc4m commented Jun 28, 2025

This is cool. Here are results from my hardware (with coopmat/coopmat2 disabled):

| Device | TFLOPS Indirect | TFLOPS Direct | TFLOPS Direct (fe85b44) |
| --- | --- | --- | --- |
| Nvidia RTX 3090 | 12.25 | 12.06 | 15.14 |
| AMD Radeon Pro VII | 5.65 | 4.36 | 5.91 |
| Intel A770 | 1.65 | 3.95 | 6.66 |

@etasnadi
Contributor Author

etasnadi commented Jun 28, 2025

This is cool. Here are results from my hardware (with coopmat/coopmat2 disabled):
| Device | TFLOPS Indirect | TFLOPS Direct | TFLOPS Direct (fe85b44) |
| --- | --- | --- | --- |
| Nvidia RTX 3090 | 12.25 | 12.06 | 15.14 |
| AMD Radeon Pro VII | 5.65 | 4.36 | 5.91 |
| Intel A770 | 1.65 | 3.95 | 6.66 |

Unfortunately this branch contains a logical error, and so does the code in this pull request, so the improvement is smaller. I will push the corrected version today (it still has an edge over the indirect impl on my 2060, so it might be worth testing). I have not updated it recently because I am working on better shared memory handling, since the bank conflicts slow down the kernel too much.

Edit: deleted the branch to prevent further confusion.

@0cc4m
Collaborator

0cc4m commented Jun 29, 2025

Fair enough, I'll redo the test once you publish the fixed version.

@etasnadi
Contributor Author

Fair enough, I'll redo the test once you publish the fixed version.

@0cc4m I've fixed some trivial errors in my out-of-tree update (etasnadi@50a29f4), so I am curious how fast it is on your devices.

My experiments show that it is 15% faster than im2col+SGEMM (the indirect implementation) on my Pascal device for a large matrix, and 40% faster on my Turing desktop GPU on a large problem (4096x4096x4096), while using far less memory: my algorithm does not store the im2col matrix, which takes up batch_size x output_width x output_height x input_channels x kernel_width x kernel_height elements. The performance improvement is sometimes 2x on smaller problems. I expect that if the problem is large enough, the computation speeds would converge even if the direct implementation is optimal.
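
For a sense of scale, a back-of-the-envelope estimate (my own arithmetic, assuming fp32) of the im2col matrix for the large benchmark case above (input 19x19x256, batch 16, kernel 4x4x256x4096, stride 1, no padding):

```cpp
// Rough im2col footprint for the large benchmark case (illustrative arithmetic).
#include <cstdio>

int main() {
    const long long N = 16, C = 256, H = 19, W = 19;  // input
    const long long KH = 4, KW = 4;                    // kernel spatial size
    const long long OH = H - KH + 1, OW = W - KW + 1;  // stride 1, no padding -> 16x16

    // The im2col matrix has (N*OH*OW) rows and (C*KH*KW) columns.
    const long long elems = N * OH * OW * C * KH * KW;
    printf("im2col matrix: %lld elements, %.1f MiB as fp32\n",
           elems, elems * 4.0 / (1024.0 * 1024.0));    // 16777216 elements, ~64 MiB
    return 0;
}
```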

@netrunnereve I added a few test cases for performance measurements of shapes that are common in convolutional neural networks.

I simplified my ifs to minimize branch divergence, so I do not need macros anymore (I had falsely assumed that the compiler would do this). This made my kernel as fast as the indirect op on my old device.

The kernel executes many non-constant divisions when loading data, which seriously affected performance on my older device (probably because of its limited number of SFUs), so I added support for collective ops (warp shuffle) to mitigate the issue. Such ops were introduced with Kepler, but they can be disabled with a macro if we want to support even older hardware.

My code still has serious bank conflicts that I chose not to eliminate yet, because the fix would not be compatible with coopmats.
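
For context, the textbook fix is to pad the shared-memory tile by one element per row so that strided accesses spread across banks. The sketch below only demonstrates that bank arithmetic on the host under the usual 32-bank, 4-byte model; it is not code from this PR, and as noted above this kind of padding may not be compatible with the coopmat layout.

```cpp
// Host-side illustration of why padding a shared-memory tile avoids bank
// conflicts. Assumes the common model of 32 banks, 4 bytes each; this is an
// assumption for illustration, not code from the PR.
#include <cstdio>
#include <set>

static int distinct_banks(int row_stride) {
    // 32 threads each read element (t, 0) of a row-major float tile.
    std::set<int> banks;
    for (int t = 0; t < 32; ++t) {
        const int word_index = t * row_stride;  // one float = one 4-byte word
        banks.insert(word_index % 32);
    }
    return (int)banks.size();
}

int main() {
    printf("stride 32 (unpadded 32-wide tile): %d distinct bank(s) -> conflicts\n", distinct_banks(32));
    printf("stride 33 (tile padded by 1):      %d distinct bank(s) -> conflict-free\n", distinct_banks(33));
    return 0;
}
```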

GTX 1060 (Notebook) (Pascal)
============================

$ GGML_VK_DISABLE_COOPMAT=1 GGML_VK_DISABLE_COOPMAT2=1 ./bin/test-backend-ops -o CONV_2D_DIRECT_IMPL -b Vulkan0 perf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce GTX 1060 (NVIDIA) | uma: 0 | fp16: 0 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: none
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: NVIDIA GeForce GTX 1060
  Device memory: 6144 MB (6144 MB free)

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 15 runs - 69691.60 us/run - 137.42 GFLOP/run -   1.97 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    7480 runs -   134.97 us/run - 133.69 MFLOP/run - 990.53 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    5159 runs -   221.34 us/run - 135.78 MFLOP/run - 613.46 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):             24576 runs -    53.20 us/run - 642.82 KFLOP/run -  12.08 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4786 runs -   532.73 us/run -  20.90 MFLOP/run -  39.23 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     8192 runs -   426.01 us/run -   2.78 MFLOP/run -   6.54 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -  3332.21 us/run -  22.28 MFLOP/run -   6.69 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    5202 runs -   193.92 us/run - 115.40 MFLOP/run - 595.10 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     981 runs -  1045.70 us/run - 923.24 MFLOP/run - 882.89 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  880 runs -  1178.65 us/run -   1.85 GFLOP/run -   1.57 TFLOPS
  Backend Vulkan0: OK

Backend 2/2: CPU
  Skipping
2/2 backends passed
OK

$ GGML_VK_DISABLE_COOPMAT=1 GGML_VK_DISABLE_COOPMAT2=1 ./bin/test-backend-ops -o CONV_2D_INDIRECT_IMPL -b Vulkan0 perf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce GTX 1060 (NVIDIA) | uma: 0 | fp16: 0 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: none
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: NVIDIA GeForce GTX 1060
  Device memory: 6144 MB (6144 MB free)

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       13 runs - 79762.46 us/run - 137.42 GFLOP/run -   1.72 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  4488 runs -   254.70 us/run - 133.69 MFLOP/run - 524.90 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  2948 runs -   387.09 us/run - 135.78 MFLOP/run - 350.78 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   18432 runs -    54.67 us/run - 642.82 KFLOP/run -  11.76 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   2048 runs -   738.82 us/run -  20.90 MFLOP/run -  28.28 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   2048 runs -   593.18 us/run -   2.78 MFLOP/run -   4.69 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   1024 runs -  4524.83 us/run -  22.28 MFLOP/run -   4.92 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  4335 runs -   262.02 us/run - 115.40 MFLOP/run - 440.44 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   545 runs -  1982.49 us/run - 923.24 MFLOP/run - 465.70 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                605 runs -  1717.93 us/run -   1.85 GFLOP/run -   1.08 TFLOPS
  Backend Vulkan0: OK

Backend 2/2: CPU
  Skipping
2/2 backends passed
OK

RTX 2060 super (Turing)
=======================

$ GGML_VK_DISABLE_COOPMAT=1 GGML_VK_DISABLE_COOPMAT2=1 ./bin/test-backend-ops -o CONV_2D_DIRECT_IMPL -b Vulkan0 perf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 2060 SUPER (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: none
register_backend: registered backend Vulkan (1 devices)
register_device: registered device Vulkan0 (NVIDIA GeForce RTX 2060 SUPER)
register_backend: registered backend CPU (1 devices)
register_device: registered device CPU (Intel(R) Core(TM) i5-8400 CPU @ 2.80GHz)
load_backend: failed to find ggml_backend_init in /home/etasnadi/llama.cppxx-vulkan/build/bin/libggml-vulkan.so
load_backend: failed to find ggml_backend_init in /home/etasnadi/llama.cppxx-vulkan/build/bin/libggml-cpu.so
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: NVIDIA GeForce RTX 2060 SUPER
  Device memory: 8192 MB (8192 MB free)

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 34 runs - 30198.44 us/run - 137.42 GFLOP/run -   4.55 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   17952 runs -    57.88 us/run - 133.69 MFLOP/run -   2.31 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   12529 runs -    79.82 us/run - 135.78 MFLOP/run -   1.70 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):             65536 runs -    16.72 us/run - 642.82 KFLOP/run -  38.46 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     9572 runs -   133.05 us/run -  20.90 MFLOP/run - 157.06 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    16384 runs -   100.10 us/run -   2.78 MFLOP/run -  27.82 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -   748.84 us/run -  22.28 MFLOP/run -  29.75 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   12138 runs -    82.51 us/run - 115.40 MFLOP/run -   1.40 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    2289 runs -   437.40 us/run - 923.24 MFLOP/run -   2.11 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 1760 runs -   573.97 us/run -   1.85 GFLOP/run -   3.22 TFLOPS
  Backend Vulkan0: OK

Backend 2/2: CPU
  Skipping
2/2 backends passed
OK

$ GGML_VK_DISABLE_COOPMAT=1 GGML_VK_DISABLE_COOPMAT2=1 ./bin/test-backend-ops -o CONV_2D_INDIRECT_IMPL -b Vulkan0 perf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 2060 SUPER (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: none
register_backend: registered backend Vulkan (1 devices)
register_device: registered device Vulkan0 (NVIDIA GeForce RTX 2060 SUPER)
register_backend: registered backend CPU (1 devices)
register_device: registered device CPU (Intel(R) Core(TM) i5-8400 CPU @ 2.80GHz)
load_backend: failed to find ggml_backend_init in /home/etasnadi/llama.cppxx-vulkan/build/bin/libggml-vulkan.so
load_backend: failed to find ggml_backend_init in /home/etasnadi/llama.cppxx-vulkan/build/bin/libggml-cpu.so
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: NVIDIA GeForce RTX 2060 SUPER
  Device memory: 8192 MB (8192 MB free)

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       24 runs - 42034.92 us/run - 137.42 GFLOP/run -   3.27 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  8228 runs -   129.57 us/run - 133.69 MFLOP/run -   1.03 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  7370 runs -   142.80 us/run - 135.78 MFLOP/run - 950.89 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   36864 runs -    27.37 us/run - 642.82 KFLOP/run -  23.49 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   3072 runs -   346.84 us/run -  20.90 MFLOP/run -  60.25 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   4096 runs -   272.06 us/run -   2.78 MFLOP/run -  10.24 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   1024 runs -  2064.20 us/run -  22.28 MFLOP/run -  10.79 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 12138 runs -    85.26 us/run - 115.40 MFLOP/run -   1.35 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  2071 runs -   494.73 us/run - 923.24 MFLOP/run -   1.87 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               1210 runs -   864.86 us/run -   1.85 GFLOP/run -   2.14 TFLOPS
  Backend Vulkan0: OK

Backend 2/2: CPU
  Skipping
2/2 backends passed
OK

@0cc4m
Collaborator

0cc4m commented Jul 12, 2025

Here are updated values using your new branch:

| Device | TFLOPS Indirect | TFLOPS Direct |
| --- | --- | --- |
| Nvidia RTX 3090 | 12.28 | 15.19 |
| AMD Radeon Pro VII | 5.68 | 5.83 |
| Intel A770 | 1.65 | 5.11 |

Performance looks good. test-backend-ops -o CONV_2D_DIRECT_IMPL shows correct results on AMD and Nvidia, and a lot of failures on Intel, but that is sadly not unusual and not a reason not to merge.

@etasnadi
Contributor Author

Here are updated values using your new branch:
| Device | TFLOPS Indirect | TFLOPS Direct |
| --- | --- | --- |
| Nvidia RTX 3090 | 12.28 | 15.19 |
| AMD Radeon Pro VII | 5.68 | 5.83 |
| Intel A770 | 1.65 | 5.11 |

Performance looks good. test-backend-ops -o CONV_2D_DIRECT_IMPL shows correct results on AMD and Nvidia, and a lot of failures on Intel, but that is sadly not unusual and not a reason not to merge.

I guess these are the mean FLOPS over all test cases.

I don't know what the issue is with Intel; can you attach a log so I can see which tests are failing?

@netrunnereve
Collaborator

Here are my numbers with etasnadi@50a29f4 on my 470. The last test is running a bit slower but otherwise everything looks good.

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 16 runs - 64554.25 us/run - 137.42 GFLOP/run -   2.13 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    7480 runs -   144.56 us/run - 133.69 MFLOP/run - 924.83 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    5896 runs -   170.54 us/run - 135.78 MFLOP/run - 796.19 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):             40960 runs -    27.68 us/run - 642.82 KFLOP/run -  23.23 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4786 runs -   230.89 us/run -  20.90 MFLOP/run -  90.50 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     8192 runs -   144.69 us/run -   2.78 MFLOP/run -  19.25 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -   999.69 us/run -  22.28 MFLOP/run -  22.29 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    5202 runs -   211.46 us/run - 115.40 MFLOP/run - 545.76 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    1090 runs -   930.41 us/run - 923.24 MFLOP/run - 992.29 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  880 runs -  1149.22 us/run -   1.85 GFLOP/run -   1.61 TFLOPS

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       14 runs - 73065.07 us/run - 137.42 GFLOP/run -   1.88 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  5984 runs -   183.60 us/run - 133.69 MFLOP/run - 728.18 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  5159 runs -   208.39 us/run - 135.78 MFLOP/run - 651.58 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   15360 runs -    65.95 us/run - 642.82 KFLOP/run -   9.75 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   1024 runs -  1112.95 us/run -  20.90 MFLOP/run -  18.78 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   2048 runs -   595.00 us/run -   2.78 MFLOP/run -   4.68 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   1024 runs -  4403.04 us/run -  22.28 MFLOP/run -   5.06 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  3468 runs -   288.74 us/run - 115.40 MFLOP/run - 399.68 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   763 runs -  1311.67 us/run - 923.24 MFLOP/run - 703.86 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                990 runs -  1061.99 us/run -   1.85 GFLOP/run -   1.74 TFLOPS

@0cc4m did you run your tests using the specific commit etasnadi@50a29f4? The ggml/conv_2d branch is outdated and only has a single test.

etasnadi added 2 commits July 12, 2025 23:49

* ggml-vulkan: adds f32 scalar shader to compute 2D convolution directly with gemm (no need for im2col)
* test-backend-ops: adds test_case_ref to check the validity/performance of ops against reference implementations having different graphs, adds tests; eliminates redundant calculation, removes macros
* Kernel shared memory size check
* Updates test-backend-ops to support graphs for performance measurement
@0cc4m
Collaborator

0cc4m commented Jul 13, 2025

I used the new branch, but only copied the TFLOPS number from the first test, the large one. I made sure the others were improved as well, though.

* Subgroup size used to determine tile size -> fixes llvmpipe errors.
@etasnadi
Contributor Author

Here are updated values using your new branch:
| Device | TFLOPS Indirect | TFLOPS Direct |
| --- | --- | --- |
| Nvidia RTX 3090 | 12.28 | 15.19 |
| AMD Radeon Pro VII | 5.68 | 5.83 |
| Intel A770 | 1.65 | 5.11 |

Performance looks good. test-backend-ops -o CONV_2D_DIRECT_IMPL shows correct results on AMD and Nvidia, and a lot of failures on Intel, but that is sadly not unusual and not a reason not to merge.

7f9b659 might work on Intel too (at least it is now working in llvmpipe; the previous commit that you tested contained a bug dependent on the subgroup size).

@etasnadi
Contributor Author

etasnadi commented Jul 14, 2025

@0cc4m I refactored test-backend-ops significantly to support evaluating against composite ops (to be able to compare the result/perf of a single op to the output of a graph), so somebody might want to check it:

  • test_case#eval made virtual,
  • test_case#eval_perf needs to add multiple nodes to a graph when composite ops are being tested. The general logic is not implemented yet, so conv2d is handled in an if temporarily (I am not sure that my idea of comparing against composite ops is supported by the community),
  • adds test_case_ref to compare the result of a single op to a composite op (see the sketch below).
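
A rough, hypothetical sketch of the shape of such a comparison (these classes are made up for illustration and are not the actual test-backend-ops types; the real harness builds and runs ggml graphs on a backend instead of returning vectors directly):

```cpp
// Hypothetical sketch of comparing a single op against a composite reference
// graph. These classes are made up; the real test-backend-ops builds and runs
// ggml graphs on a backend instead of returning vectors directly.
#include <cmath>
#include <cstdio>
#include <vector>

struct test_case {
    virtual std::vector<float> eval() = 0;  // run the graph, return the output
    virtual ~test_case() = default;
};

// Op under development: a single (possibly fused) node.
struct test_direct_op : test_case {
    std::vector<float> eval() override { return {2.0f, 4.0f, 6.0f}; }
};

// Reference: the same result produced by a graph of existing ops.
struct test_reference_graph : test_case {
    std::vector<float> eval() override {
        std::vector<float> t = {1.0f, 2.0f, 3.0f};  // first node of the graph
        for (float &v : t) v *= 2.0f;               // second node of the graph
        return t;
    }
};

// test_case_ref-style wrapper: evaluate both and compare element-wise.
static bool compare(test_case &op, test_case &ref, float tol) {
    const std::vector<float> a = op.eval(), b = ref.eval();
    if (a.size() != b.size()) return false;
    for (int i = 0; i < (int)a.size(); ++i) {
        if (std::fabs(a[i] - b[i]) > tol) return false;
    }
    return true;
}

int main() {
    test_direct_op op;
    test_reference_graph ref;
    printf("%s\n", compare(op, ref, 1e-6f) ? "OK" : "FAIL");
    return 0;
}
```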

@0cc4m
Collaborator

0cc4m commented Jul 14, 2025

Here are full results:

ggml_vulkan: 0 = NVIDIA GeForce RTX 3090 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: none

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       90 runs - 11142.36 us/run - 137.42 GFLOP/run -  12.33 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 16456 runs -    61.25 us/run - 133.69 MFLOP/run -   2.18 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 16951 runs -    60.02 us/run - 135.78 MFLOP/run -   2.26 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   57344 runs -    17.48 us/run - 642.82 kFLOP/run -  36.78 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   7168 runs -   146.42 us/run -  20.90 MFLOP/run - 142.71 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   9216 runs -   114.07 us/run -   2.78 MFLOP/run -  24.41 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   2048 runs -   832.51 us/run -  22.28 MFLOP/run -  26.76 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 19074 runs -    54.11 us/run - 115.40 MFLOP/run -   2.13 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  4687 runs -   213.66 us/run - 923.24 MFLOP/run -   4.32 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               2970 runs -   341.12 us/run -   1.85 GFLOP/run -   5.42 TFLOPS

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                113 runs -  8925.79 us/run - 137.42 GFLOP/run -  15.40 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   30668 runs -    33.01 us/run - 133.69 MFLOP/run -   4.05 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   30217 runs -    33.20 us/run - 135.78 MFLOP/run -   4.09 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):            106496 runs -     9.75 us/run - 642.82 kFLOP/run -  65.92 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    23930 runs -    50.34 us/run -  20.90 MFLOP/run - 415.07 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    32768 runs -    39.65 us/run -   2.78 MFLOP/run -  70.24 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -   252.37 us/run -  22.28 MFLOP/run -  88.28 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   19074 runs -    53.87 us/run - 115.40 MFLOP/run -   2.14 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    6540 runs -   153.76 us/run - 923.24 MFLOP/run -   6.00 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 4675 runs -   214.96 us/run -   1.85 GFLOP/run -   8.60 TFLOPS

ggml_vulkan: 0 = AMD Radeon (TM) Pro VII (RADV VEGA20) (radv) | uma: 0 | fp16: 1 | warp size: 64 | shared memory: 65536 | int dot: 1 | matrix cores: none

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       42 runs - 24203.93 us/run - 137.42 GFLOP/run -   5.68 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 11220 runs -    91.08 us/run - 133.69 MFLOP/run -   1.47 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 10318 runs -    97.49 us/run - 135.78 MFLOP/run -   1.39 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   29696 runs -    33.92 us/run - 642.82 kFLOP/run -  18.95 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   3072 runs -   395.17 us/run -  20.90 MFLOP/run -  52.88 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   5120 runs -   231.24 us/run -   2.78 MFLOP/run -  12.04 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   1024 runs -  1873.64 us/run -  22.28 MFLOP/run -  11.89 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  5202 runs -   211.76 us/run - 115.40 MFLOP/run - 544.99 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  1853 runs -   571.89 us/run - 923.24 MFLOP/run -   1.61 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               2145 runs -   477.77 us/run -   1.85 GFLOP/run -   3.87 TFLOPS

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 43 runs - 23415.81 us/run - 137.42 GFLOP/run -   5.87 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   10472 runs -   101.51 us/run - 133.69 MFLOP/run -   1.32 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    8844 runs -   115.42 us/run - 135.78 MFLOP/run -   1.18 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):             81920 runs -    13.17 us/run - 642.82 kFLOP/run -  48.80 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    14358 runs -    86.67 us/run -  20.90 MFLOP/run - 241.11 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    24576 runs -    53.64 us/run -   2.78 MFLOP/run -  51.92 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -   349.93 us/run -  22.28 MFLOP/run -  63.67 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    7803 runs -   129.15 us/run - 115.40 MFLOP/run - 893.60 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    2725 runs -   377.86 us/run - 923.24 MFLOP/run -   2.44 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 1980 runs -   515.94 us/run -   1.85 GFLOP/run -   3.58 TFLOPS

ggml_vulkan: 0 = Intel(R) Arc(tm) A770 Graphics (DG2) (Intel open-source Mesa driver) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 65536 | int dot: 1 | matrix cores: none

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       13 runs - 83219.23 us/run - 137.42 GFLOP/run -   1.65 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  8976 runs -   116.22 us/run - 133.69 MFLOP/run -   1.15 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  6633 runs -   154.83 us/run - 135.78 MFLOP/run - 876.98 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   34816 runs -    29.03 us/run - 642.82 kFLOP/run -  22.15 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   5120 runs -   214.65 us/run -  20.90 MFLOP/run -  97.35 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   6144 runs -   183.09 us/run -   2.78 MFLOP/run -  15.21 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   1024 runs -  1335.37 us/run -  22.28 MFLOP/run -  16.68 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  6069 runs -   186.28 us/run - 115.40 MFLOP/run - 619.51 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   981 runs -  1095.76 us/run - 923.24 MFLOP/run - 842.55 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                660 runs -  1528.64 us/run -   1.85 GFLOP/run -   1.21 TFLOPS

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  2 runs - 1005052.50 us/run - 137.42 GFLOP/run - 136.73 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     748 runs -  1683.96 us/run - 133.69 MFLOP/run -  79.39 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     737 runs -  2386.06 us/run - 135.78 MFLOP/run -  56.91 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):              8192 runs -   370.88 us/run - 642.82 kFLOP/run -   1.73 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4786 runs -  2993.49 us/run -  20.90 MFLOP/run -   6.98 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     8192 runs -  1623.20 us/run -   2.78 MFLOP/run -   1.72 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -  4458.65 us/run -  22.28 MFLOP/run -   5.00 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     867 runs -  3623.83 us/run - 115.40 MFLOP/run -  31.85 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     109 runs - 10027.13 us/run - 923.24 MFLOP/run -  92.07 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   55 runs - 20718.15 us/run -   1.85 GFLOP/run -  89.24 GFLOPS

The tests now pass on Intel, but performance is terrible. There might be some subgroup size tuning we can try to fix this.

@etasnadi
Contributor Author

Here are full results:

ggml_vulkan: 0 = NVIDIA GeForce RTX 3090 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: none

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       90 runs - 11142.36 us/run - 137.42 GFLOP/run -  12.33 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 16456 runs -    61.25 us/run - 133.69 MFLOP/run -   2.18 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 16951 runs -    60.02 us/run - 135.78 MFLOP/run -   2.26 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   57344 runs -    17.48 us/run - 642.82 kFLOP/run -  36.78 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   7168 runs -   146.42 us/run -  20.90 MFLOP/run - 142.71 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   9216 runs -   114.07 us/run -   2.78 MFLOP/run -  24.41 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   2048 runs -   832.51 us/run -  22.28 MFLOP/run -  26.76 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 19074 runs -    54.11 us/run - 115.40 MFLOP/run -   2.13 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  4687 runs -   213.66 us/run - 923.24 MFLOP/run -   4.32 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               2970 runs -   341.12 us/run -   1.85 GFLOP/run -   5.42 TFLOPS

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                113 runs -  8925.79 us/run - 137.42 GFLOP/run -  15.40 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   30668 runs -    33.01 us/run - 133.69 MFLOP/run -   4.05 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   30217 runs -    33.20 us/run - 135.78 MFLOP/run -   4.09 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):            106496 runs -     9.75 us/run - 642.82 kFLOP/run -  65.92 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    23930 runs -    50.34 us/run -  20.90 MFLOP/run - 415.07 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    32768 runs -    39.65 us/run -   2.78 MFLOP/run -  70.24 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -   252.37 us/run -  22.28 MFLOP/run -  88.28 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   19074 runs -    53.87 us/run - 115.40 MFLOP/run -   2.14 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    6540 runs -   153.76 us/run - 923.24 MFLOP/run -   6.00 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 4675 runs -   214.96 us/run -   1.85 GFLOP/run -   8.60 TFLOPS

ggml_vulkan: 0 = AMD Radeon (TM) Pro VII (RADV VEGA20) (radv) | uma: 0 | fp16: 1 | warp size: 64 | shared memory: 65536 | int dot: 1 | matrix cores: none

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       42 runs - 24203.93 us/run - 137.42 GFLOP/run -   5.68 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 11220 runs -    91.08 us/run - 133.69 MFLOP/run -   1.47 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 10318 runs -    97.49 us/run - 135.78 MFLOP/run -   1.39 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   29696 runs -    33.92 us/run - 642.82 kFLOP/run -  18.95 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   3072 runs -   395.17 us/run -  20.90 MFLOP/run -  52.88 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   5120 runs -   231.24 us/run -   2.78 MFLOP/run -  12.04 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   1024 runs -  1873.64 us/run -  22.28 MFLOP/run -  11.89 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  5202 runs -   211.76 us/run - 115.40 MFLOP/run - 544.99 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  1853 runs -   571.89 us/run - 923.24 MFLOP/run -   1.61 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               2145 runs -   477.77 us/run -   1.85 GFLOP/run -   3.87 TFLOPS

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 43 runs - 23415.81 us/run - 137.42 GFLOP/run -   5.87 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   10472 runs -   101.51 us/run - 133.69 MFLOP/run -   1.32 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    8844 runs -   115.42 us/run - 135.78 MFLOP/run -   1.18 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):             81920 runs -    13.17 us/run - 642.82 kFLOP/run -  48.80 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    14358 runs -    86.67 us/run -  20.90 MFLOP/run - 241.11 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    24576 runs -    53.64 us/run -   2.78 MFLOP/run -  51.92 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -   349.93 us/run -  22.28 MFLOP/run -  63.67 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    7803 runs -   129.15 us/run - 115.40 MFLOP/run - 893.60 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    2725 runs -   377.86 us/run - 923.24 MFLOP/run -   2.44 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 1980 runs -   515.94 us/run -   1.85 GFLOP/run -   3.58 TFLOPS

ggml_vulkan: 0 = Intel(R) Arc(tm) A770 Graphics (DG2) (Intel open-source Mesa driver) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 65536 | int dot: 1 | matrix cores: none

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       13 runs - 83219.23 us/run - 137.42 GFLOP/run -   1.65 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  8976 runs -   116.22 us/run - 133.69 MFLOP/run -   1.15 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  6633 runs -   154.83 us/run - 135.78 MFLOP/run - 876.98 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   34816 runs -    29.03 us/run - 642.82 kFLOP/run -  22.15 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   5120 runs -   214.65 us/run -  20.90 MFLOP/run -  97.35 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   6144 runs -   183.09 us/run -   2.78 MFLOP/run -  15.21 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   1024 runs -  1335.37 us/run -  22.28 MFLOP/run -  16.68 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  6069 runs -   186.28 us/run - 115.40 MFLOP/run - 619.51 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   981 runs -  1095.76 us/run - 923.24 MFLOP/run - 842.55 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                660 runs -  1528.64 us/run -   1.85 GFLOP/run -   1.21 TFLOPS

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  2 runs - 1005052.50 us/run - 137.42 GFLOP/run - 136.73 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     748 runs -  1683.96 us/run - 133.69 MFLOP/run -  79.39 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     737 runs -  2386.06 us/run - 135.78 MFLOP/run -  56.91 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):              8192 runs -   370.88 us/run - 642.82 kFLOP/run -   1.73 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4786 runs -  2993.49 us/run -  20.90 MFLOP/run -   6.98 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     8192 runs -  1623.20 us/run -   2.78 MFLOP/run -   1.72 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -  4458.65 us/run -  22.28 MFLOP/run -   5.00 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     867 runs -  3623.83 us/run - 115.40 MFLOP/run -  31.85 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     109 runs - 10027.13 us/run - 923.24 MFLOP/run -  92.07 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   55 runs - 20718.15 us/run -   1.85 GFLOP/run -  89.24 GFLOPS

The tests now pass on Intel, but performance is terrible. There might be some subgroup-size tuning we can try to fix this.

The new commit a09e8f5 disables subgroup ops completely by default (they can be re-enabled with GGML_VK_USE_COLLECTIVES), so we can see whether they are what hurts Intel.
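
Roughly, the toggle amounts to an environment-variable check at shader-load time, along the lines of this sketch (the exact code in a09e8f5 may differ):

#include <cstdlib>  // std::getenv

// Sketch only: collective (subgroup) ops stay disabled unless the user opts in.
const bool use_collectives_requested = std::getenv("GGML_VK_USE_COLLECTIVES") != nullptr;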

@etasnadi
Contributor Author

@0cc4m Please also report the string " --> BS_CRS=%d use_collectives=%d" printed to stderr on Intel; that should help confirm that the subgroup sizes are configured properly.

@0cc4m
Collaborator

0cc4m commented Jul 15, 2025

 --> BS_CRS=16 use_collectives=0
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  2 runs - 970593.00 us/run - 137.42 GFLOP/run - 141.59 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     748 runs -  1617.97 us/run - 133.69 MFLOP/run -  82.63 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     737 runs -  2296.76 us/run - 135.78 MFLOP/run -  59.12 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):              8192 runs -   345.04 us/run - 642.82 kFLOP/run -   1.86 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4786 runs -  2837.86 us/run -  20.90 MFLOP/run -   7.36 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     8192 runs -  1528.24 us/run -   2.78 MFLOP/run -   1.82 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -  4459.15 us/run -  22.28 MFLOP/run -   5.00 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     867 runs -  3472.66 us/run - 115.40 MFLOP/run -  33.23 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     109 runs - 10318.41 us/run - 923.24 MFLOP/run -  89.47 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   55 runs - 19954.25 us/run -   1.85 GFLOP/run -  92.66 GFLOPS
 --> BS_CRS=16 use_collectives=1
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  2 runs - 997085.50 us/run - 137.42 GFLOP/run - 137.82 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     748 runs -  1696.24 us/run - 133.69 MFLOP/run -  78.82 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     737 runs -  2384.99 us/run - 135.78 MFLOP/run -  56.93 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):              8192 runs -   371.66 us/run - 642.82 kFLOP/run -   1.73 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4786 runs -  2997.28 us/run -  20.90 MFLOP/run -   6.97 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     8192 runs -  1624.10 us/run -   2.78 MFLOP/run -   1.71 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -  4458.56 us/run -  22.28 MFLOP/run -   5.00 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     867 runs -  3623.94 us/run - 115.40 MFLOP/run -  31.85 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     109 runs - 10017.94 us/run - 923.24 MFLOP/run -  92.16 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   55 runs - 20687.73 us/run -   1.85 GFLOP/run -  89.37 GFLOPS

What does help is forcing the subgroup size to BS_CRS=16 on Intel:

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 2eb7415c5..e1bba0f84 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -3054,10 +3054,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
             conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
         }
     }
-
+
     std::cerr << " --> BS_CRS=" << conv2d_BS_CRS << " use_collectives=" << use_collectives << std::endl;

-    if(device->subgroup_shuffle){
+    if(device->subgroup_shuffle && device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16){
+        ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives}, 1, true, true, 16);
+    }else if(device->subgroup_shuffle){
         ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives}, 1, true, true);
     }else{
         ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants), {conv2d_BS_K, conv2d_BS_NPQ, 1}, {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives}, 1, true);
 --> BS_CRS=16 use_collectives=1
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 10 runs - 103033.90 us/run - 137.42 GFLOP/run -   1.33 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    2992 runs -   369.98 us/run - 133.69 MFLOP/run - 361.36 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    2948 runs -   389.91 us/run - 135.78 MFLOP/run - 348.24 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):             24576 runs -    59.38 us/run - 642.82 kFLOP/run -  10.82 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4786 runs -   360.18 us/run -  20.90 MFLOP/run -  58.02 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     8192 runs -   202.12 us/run -   2.78 MFLOP/run -  13.78 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -  1260.09 us/run -  22.28 MFLOP/run -  17.68 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    1734 runs -   780.97 us/run - 115.40 MFLOP/run - 147.77 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     763 runs -  1517.70 us/run - 923.24 MFLOP/run - 608.31 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  330 runs -  3327.69 us/run -   1.85 GFLOP/run - 555.61 GFLOPS

But this is still worse than the indirect path, and also worse than your (incorrect) earlier attempt. If you can think of something, we can give it a shot; if not, that's fine.

@etasnadi
Contributor Author

etasnadi commented Jul 15, 2025

There might be one more thing worth a shot, if you are still motivated.

I have an Intel(R) HD Graphics 630 (KBL GT2) device and assume the driver behaves similarly, so I did some debugging. It seems that on this device

  • if collective ops are used and require_full_subgroups is false, the kernel gives incorrect results,
  • if collective ops are used and require_full_subgroups is true, the results are correct but performance is bad (very bad with the default subgroup size, and merely bad when it is forced to 16, according to your test).

What we have not tried yet is disabling collectives and setting require_full_subgroups to false. My integrated Intel device passes the tests with this setting, so I am curious what the performance is on the A770.

Can you set use_collectives=0 and call the pipeline creation with require_full_subgroups=0? (I only added collectives to mitigate the low SFU throughput on older hardware, so it might not help the A770 anyway.)
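
Concretely, something like this untested sketch, mirroring the else branch of your diff above but with the collectives spec constant forced to 0 (and no trailing require_full_subgroups flag):

// Sketch only: collectives disabled, require_full_subgroups left at its default (false).
const uint32_t use_collectives = 0;
ggml_vk_create_pipeline(device, device->pipeline_conv2d_f32, "conv2d_f32",
    conv2d_f32_len, conv2d_f32_data, "main", 3, sizeof(vk_op_conv2d_push_constants),
    {conv2d_BS_K, conv2d_BS_NPQ, 1},
    {conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives},
    1, true);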

Please use a09e8f5, as the most recent version disables Intel support.

@0cc4m
Collaborator

0cc4m commented Jul 15, 2025

Yeah, when not forcing full subgroups and collectives, it works correctly and is fast:

ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Intel(R) Arc(tm) A770 Graphics (DG2) (Intel open-source Mesa driver) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 65536 | int dot: 1 | matrix cores: none
Testing 2 devices

 --> BS_CRS=16 use_collectives=0
Backend 1/2: Vulkan0
  Device description: Intel(R) Arc(tm) A770 Graphics (DG2)
  Device memory: 16032 MB (16032 MB free)

 --> BS_CRS=16 use_collectives=0
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 42 runs - 24048.05 us/run - 137.42 GFLOP/run -   5.71 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   14212 runs -    72.92 us/run - 133.69 MFLOP/run -   1.83 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   13266 runs -    78.70 us/run - 135.78 MFLOP/run -   1.73 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):             81920 runs -    13.00 us/run - 642.82 kFLOP/run -  49.43 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    14358 runs -    76.51 us/run -  20.90 MFLOP/run - 273.11 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    24576 runs -    45.81 us/run -   2.78 MFLOP/run -  60.80 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -   294.73 us/run -  22.28 MFLOP/run -  75.59 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    8670 runs -   125.63 us/run - 115.40 MFLOP/run - 918.58 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    3052 runs -   329.87 us/run - 923.24 MFLOP/run -   2.80 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 1925 runs -   532.86 us/run -   1.85 GFLOP/run -   3.47 TFLOPS

But we still want to use them on Nvidia and AMD since they make a measurable positive difference there.

As a side note, you might have triggered a MoltenVK shader compiler bug:

ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Apple M4 Max (MoltenVK) | uma: 1 | fp16: 1 | warp size: 32 | shared memory: 32768 | int dot: 0 | matrix cores: none
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: Apple M4 Max
  Device memory: 65536 MB (65536 MB free)

ggml_vulkan: Compute pipeline creation failed for conv2d_f32
ggml_vulkan: vk::Device::createComputePipeline: ErrorInitializationFailed
[1]    10651 segmentation fault  build_vk/bin/test-backend-ops -o CONV_2D_DIRECT_IMPL

@etasnadi
Contributor Author

Great, I did not expect a 2.5-3x speedup (and correct output at the same time). The collectives are not a problem, because a new commit disables them for Intel and enables them otherwise.

The Apple bug is entirely new info for me; doesn't the repo have a GitHub pipeline to check which versions of my code crashed? I do not have access to any Apple devices at the moment, and this is probably not enough info to locate the error. A debug-build execution log and validation-layer output might show more and help narrow it down.

Do all other kernels pass the tests on MoltenVK? Could you please tweak the usual hotspots, e.g. enable/disable the shuffle op and full_subgroups? If the problem is with the subgroups, then either I am doing something illegal in the kernel that does not show up on Nvidia/AMD, or both MoltenVK and the Intel driver have a bug.

Thanks!

@0cc4m
Copy link
Collaborator

0cc4m commented Jul 15, 2025

It happens regardless of the collectives or full_subgroups setting. Something else is triggering it. Some other shaders are failing to output correct results on Apple, but your conv2d is the only one that crashes on build currently. Without the hardware there's nothing you can do to debug it.

I'll try to figure out how to get MoltenVK to provide more info, but it's not overly important. Apple users should be using Metal in most cases; I only know of some niche docker/VM cases where that is not possible but Vulkan is. I just tried it out of curiosity.

@etasnadi
Contributor Author

It happens regardless of the collectives or full_subgroups setting. Something else is triggering it. Some other shaders are failing to output correct results on Apple, but your conv2d is the only one that crashes on build currently. Without the hardware there's nothing you can do to debug it.

I'll try to figure out how to get MoltenVK to provide more info, but it's not overly important. Apple users should be using Metal in most cases; I only know of some niche docker/VM cases where that is not possible but Vulkan is. I just tried it out of curiosity.

I guess you mean pipeline creation when you say it crashes on build, don't you? One guess: my pipeline configures the workgroup size using a spec constant (local_size_x_id), which might not be common in other kernels. You could try removing that and setting the local size manually to 255 to see if that's the root cause of the error.

@0cc4m
Collaborator

0cc4m commented Jul 15, 2025

It crashes on pipeline creation during runtime, which is the final compile step from SPIR-V to device-specific code, in this case SPIR-V to Metal.

@etasnadi
Contributor Author

It crashes on pipeline creation during runtime, which is the final compile step from SPIR-V to device-specific code, in this case SPIR-V to Metal.

Then there is a chance that local_size_x_id is actually the problem.

@etasnadi
Contributor Author

The last commit a672803 should be fast on all tested archs, but it might need a test to confirm that the op is successfully disabled on Apple and that collectives are turned off on Intel.
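
A minimal sketch of the kind of per-device gating I mean (names invented for illustration, not the exact checks in a672803; driver_props/device_props stand for the queried VkPhysicalDeviceDriverProperties/VkPhysicalDeviceProperties):

// Hypothetical sketch: disable the direct conv2d path on MoltenVK, and collectives on Intel.
const bool is_moltenvk = driver_props.driverID == VK_DRIVER_ID_MOLTENVK;  // Apple via MoltenVK
const bool is_intel    = device_props.vendorID == 0x8086;                 // Intel PCI vendor ID
const bool conv2d_direct_supported = !is_moltenvk;  // fall back to im2col + matmul on Apple
const bool use_collectives         = !is_intel;     // keep subgroup collectives for NV/AMD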

@0cc4m
Collaborator

0cc4m commented Jul 16, 2025

Apple looks correct:

ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Apple M4 Max (MoltenVK) | uma: 1 | fp16: 1 | warp size: 32 | shared memory: 32768 | int dot: 0 | matrix cores: none
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: Apple M4 Max
  Device memory: 65536 MB (65536 MB free)

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       49 runs - 20718.57 us/run - 137.42 GFLOP/run -   6.63 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 13464 runs -    76.03 us/run - 133.69 MFLOP/run -   1.76 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 11792 runs -    85.13 us/run - 135.78 MFLOP/run -   1.59 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   27648 runs -    36.46 us/run - 642.82 kFLOP/run -  17.63 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   5120 runs -   202.51 us/run -  20.90 MFLOP/run - 103.19 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   6144 runs -   177.24 us/run -   2.78 MFLOP/run -  15.71 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   1024 runs -  1254.15 us/run -  22.28 MFLOP/run -  17.76 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  8670 runs -   118.96 us/run - 115.40 MFLOP/run - 970.14 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  2616 runs -   385.87 us/run - 923.24 MFLOP/run -   2.39 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               2255 runs -   450.08 us/run -   1.85 GFLOP/run -   4.11 TFLOPS

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0): not supported
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0): not supported
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0): not supported
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0): not supported
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0): not supported
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0): not supported
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0): not supported
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0): not supported
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0): not supported
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0): not supported

AMD:

ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = AMD Radeon (TM) Pro VII (RADV VEGA20) (radv) | uma: 0 | fp16: 1 | warp size: 64 | shared memory: 65536 | int dot: 1 | matrix cores: none
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: AMD Radeon (TM) Pro VII (RADV VEGA20)
  Device memory: 16368 MB (16368 MB free)

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 44 runs - 23151.07 us/run - 137.42 GFLOP/run -   5.94 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   10472 runs -    99.24 us/run - 133.69 MFLOP/run -   1.35 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    9581 runs -   113.03 us/run - 135.78 MFLOP/run -   1.20 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):             81920 runs -    12.80 us/run - 642.82 kFLOP/run -  50.21 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    14358 runs -    85.49 us/run -  20.90 MFLOP/run - 244.44 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    24576 runs -    52.37 us/run -   2.78 MFLOP/run -  53.18 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -   343.76 us/run -  22.28 MFLOP/run -  64.81 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    8670 runs -   127.73 us/run - 115.40 MFLOP/run - 903.51 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    2725 runs -   374.46 us/run - 923.24 MFLOP/run -   2.47 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 1980 runs -   512.14 us/run -   1.85 GFLOP/run -   3.61 TFLOPS

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       42 runs - 23839.62 us/run - 137.42 GFLOP/run -   5.76 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 11220 runs -    91.27 us/run - 133.69 MFLOP/run -   1.46 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 11055 runs -    94.82 us/run - 135.78 MFLOP/run -   1.43 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   29696 runs -    34.01 us/run - 642.82 kFLOP/run -  18.90 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   3072 runs -   402.20 us/run -  20.90 MFLOP/run -  51.95 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   5120 runs -   236.43 us/run -   2.78 MFLOP/run -  11.78 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   1024 runs -  1915.74 us/run -  22.28 MFLOP/run -  11.63 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  5202 runs -   207.85 us/run - 115.40 MFLOP/run - 555.22 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  1853 runs -   564.74 us/run - 923.24 MFLOP/run -   1.63 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               2200 runs -   464.67 us/run -   1.85 GFLOP/run -   3.98 TFLOPS

Nvidia too:

ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 3090 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: none
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: NVIDIA GeForce RTX 3090
  Device memory: 24576 MB (24576 MB free)

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                113 runs -  8918.99 us/run - 137.42 GFLOP/run -  15.41 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   30668 runs -    32.86 us/run - 133.69 MFLOP/run -   4.07 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   30954 runs -    33.03 us/run - 135.78 MFLOP/run -   4.11 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):            106496 runs -     9.68 us/run - 642.82 kFLOP/run -  66.38 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    23930 runs -    49.87 us/run -  20.90 MFLOP/run - 418.99 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    32768 runs -    39.15 us/run -   2.78 MFLOP/run -  71.14 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -   251.25 us/run -  22.28 MFLOP/run -  88.67 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   19074 runs -    53.60 us/run - 115.40 MFLOP/run -   2.15 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    6649 runs -   152.20 us/run - 923.24 MFLOP/run -   6.07 TFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 4730 runs -   213.39 us/run -   1.85 GFLOP/run -   8.66 TFLOPS

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                       90 runs - 11146.23 us/run - 137.42 GFLOP/run -  12.33 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 16456 runs -    61.21 us/run - 133.69 MFLOP/run -   2.18 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 16951 runs -    60.21 us/run - 135.78 MFLOP/run -   2.25 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   58368 runs -    17.32 us/run - 642.82 kFLOP/run -  37.11 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   7168 runs -   146.07 us/run -  20.90 MFLOP/run - 143.06 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   9216 runs -   113.38 us/run -   2.78 MFLOP/run -  24.56 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   2048 runs -   823.10 us/run -  22.28 MFLOP/run -  27.07 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 19074 runs -    53.11 us/run - 115.40 MFLOP/run -   2.17 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  4796 runs -   212.96 us/run - 923.24 MFLOP/run -   4.34 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               2915 runs -   343.63 us/run -   1.85 GFLOP/run -   5.38 TFLOPS

However, when you reenable coopmat:

ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 3090 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: KHR_coopmat
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: NVIDIA GeForce RTX 3090
  Device memory: 24576 MB (24576 MB free)

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                      195 runs -  5148.51 us/run - 137.42 GFLOP/run -  26.69 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 24684 runs -    41.55 us/run - 133.69 MFLOP/run -   3.22 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 25058 runs -    40.93 us/run - 135.78 MFLOP/run -   3.32 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   58368 runs -    17.32 us/run - 642.82 kFLOP/run -  37.11 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   7168 runs -   146.29 us/run -  20.90 MFLOP/run - 142.84 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   9216 runs -   114.00 us/run -   2.78 MFLOP/run -  24.43 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   2048 runs -   827.91 us/run -  22.28 MFLOP/run -  26.91 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 23409 runs -    43.09 us/run - 115.40 MFLOP/run -   2.68 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  5668 runs -   178.67 us/run - 923.24 MFLOP/run -   5.17 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               4620 runs -   217.94 us/run -   1.85 GFLOP/run -   8.48 TFLOPS

And coopmat2:

ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 3090 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: NVIDIA GeForce RTX 3090
  Device memory: 24576 MB (24576 MB free)

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                      354 runs -  2828.58 us/run - 137.42 GFLOP/run -  48.58 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 25432 runs -    39.87 us/run - 133.69 MFLOP/run -   3.35 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 23584 runs -    42.81 us/run - 135.78 MFLOP/run -   3.17 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   57344 runs -    17.56 us/run - 642.82 kFLOP/run -  36.60 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   7168 runs -   147.17 us/run -  20.90 MFLOP/run - 141.99 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   9216 runs -   113.84 us/run -   2.78 MFLOP/run -  24.46 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   2048 runs -   826.81 us/run -  22.28 MFLOP/run -  26.95 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 26010 runs -    39.53 us/run - 115.40 MFLOP/run -   2.92 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  4796 runs -   209.08 us/run - 923.24 MFLOP/run -   4.42 TFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               6325 runs -   158.53 us/run -   1.85 GFLOP/run -  11.66 TFLOPS

Is the direct path disabled already when coopmat or coopmat2 are available?

@etasnadi
Contributor Author

etasnadi commented Jul 16, 2025

However, when you reenable coopmat:

What do I need to see here?

The CONV_2D op itself does not support coopmats yet, and we do not have a different implementation for that op currently, so there is nothing to disable.

CONV_2D_INDIRECT_IMPL and CONV_2D_DIRECT_IMPL are only aliases that exist in test-backend-ops, not real ops. CONV_2D_DIRECT_IMPL executes a graph where the CONV_2D op computes the convolution, while CONV_2D_INDIRECT_IMPL executes a graph where IM2COL (followed by a matrix multiplication) computes it.
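
Roughly, the indirect graph is the usual ggml lowering, along these lines (simplified sketch; the im2col destination type and the final reshape/permute of the output are omitted here):

// Simplified sketch of the graph behind CONV_2D_INDIRECT_IMPL.
// kernel ne = [KW, KH, IC, OC], input ne = [IW, IH, IC, N]
ggml_tensor * cols = ggml_im2col(ctx, kernel, input,
        s0, s1, p0, p1, d0, d1, /*is_2D=*/true, GGML_TYPE_F32);   // materializes the conv matrix
ggml_tensor * out  = ggml_mul_mat(ctx,
        ggml_reshape_2d(ctx, cols, cols->ne[0], cols->ne[1] * cols->ne[2] * cols->ne[3]),
        ggml_reshape_2d(ctx, kernel, kernel->ne[0] * kernel->ne[1] * kernel->ne[2], kernel->ne[3]));
// CONV_2D_DIRECT_IMPL instead runs the single CONV_2D op, which performs the same GEMM
// inside the shader without ever writing the im2col matrix to global memory.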

@0cc4m
Collaborator

0cc4m commented Jul 16, 2025

I hadn't looked at the op details yet. In that case, when we start using the CONV_2D op in models, both direct and indirect options should be possible.

@etasnadi
Contributor Author

I hadn't looked at the op details yet. In that case, when we start using the CONV_2D op in models, both direct and indirect options should be possible.

Yes. There are advantages to the im2col version too: if a highly optimized linear-algebra library is available that is already tuned for each device (as is the case in the CUDA backend with cuBLAS), then the indirect op can be more competitive, at the cost of wasting memory.

@etasnadi
Contributor Author

@0cc4m Do I need to add anything to this PR to get it approved? One CI pipeline fails sometimes, but it does not seem to be affected by this PR.

@0cc4m
Collaborator

0cc4m commented Jul 18, 2025

I'll do my review this weekend.

Labels
ggml (changes relating to the ggml tensor library for machine learning), testing (Everything test related), Vulkan (Issues specific to the Vulkan backend)
7 participants